function [accuracy,Time] = GE_svm_contrast_conditions_perm_diftest(epochfile,train_conditionA,train_conditionB,test_conditionA,test_conditionB,param)
% GE_SVM_CONTRAST_CONDITIONS_PERM_DIFTEST  Time-resolved SVM decoding of two
% conditions, trained on one pair of condition sets and tested on another.
%
%   [accuracy,Time] = GE_svm_contrast_conditions_perm_diftest(epochfile, ...
%       train_conditionA,train_conditionB,test_conditionA,test_conditionB,param)
%
% Applies an SVM classifier (svmtrain/svmclassify with 'method','LS') to
% trial sub-averages at every time sample. Trials are randomly assigned to
% bins of param.trial_bin_size trials; this is repeated param.num_permutations
% times and the accuracies are averaged. Use this variant when training and
% testing are done on different data (different condition sets).
%
% Inputs:
%   epochfile        - .mat file containing the variables
%                        epoch_data      : indexed as epoch_data(time,channel,epoch);
%                                          each epoch is transposed to a
%                                          channels x time matrix
%                        condition_index : condition code of each epoch
%   train_conditionA - condition code(s) of training class A, either as a
%                      string parsed with str2num (e.g. '3' or '[1 2 3]') or
%                      as a numeric vector
%   train_conditionB - condition code(s) of training class B (same format)
%   test_conditionA  - condition code(s) of test class A (same format)
%   test_conditionB  - condition code(s) of test class B (same format)
%   param            - struct with fields
%                        num_permutations : number of random trial-to-bin assignments
%                        trial_bin_size   : number of trials averaged per bin
%                        time (optional)  : time axis of the epochs; defaults to
%                                           -100:1:1000 (units presumably ms —
%                                           confirm against the epoch data)
%                      Other fields (e.g. brainstorm_db, data_type, f_lowpass)
%                      are accepted but not used.
%
% Outputs:
%   accuracy - 1 x ntimes mean percent correct over permutations
%   Time     - the time axis used for the baseline (Time < 0) and returned
%              to the caller
%
% An epoch whose condition appears in several sets is assigned only to the
% first matching group, in the order train A, train B, test A, test B. Train
% and test trials therefore never overlap. An error is raised if any group
% ends up without epochs.
%
% Example:
%   %parameters
%   param.brainstorm_db = 'D:\MYPROJECTS11\project_rapid_images_Molly_Carl\Data\HagmannRSVP\data\';
%   param.data_type = 'MEG';
%   param.smooth_size = 15;
%   param.num_permutations = 30;
%   param.trial_bin_size = 5;
%
% NOTE: svmtrain/svmclassify were removed from the Statistics Toolbox in
% MATLAB R2018a; this function requires an earlier release.
%
% Author: Dimitrios Pantazis
% modified by SA 2016-06-15
% modified by SA 2018-03-30 for UAG
%    -  ROI source waveforms are assumed to be collected into "epochfile"
%       in parameters "epoch_data" and "condition_index"

%initialize
num_permutations = param.num_permutations;
trial_bin_size = param.trial_bin_size;

%% load epochs
% Load only the variables used here into a struct, so that anything else
% stored in the .mat file cannot overwrite this function's local variables.
S = load(epochfile,'epoch_data','condition_index');
if ~isfield(S,'epoch_data') || ~isfield(S,'condition_index')
    error('GE_svm_contrast_conditions_perm_diftest:missingData', ...
        '%s must contain the variables "epoch_data" and "condition_index".', epochfile);
end
epoch_data = S.epoch_data;
condition_index = S.condition_index;

%% parse condition sets (order: train A, train B, test A, test B)
group_names = {'train A','train B','test A','test B'};
cond_in = {train_conditionA,train_conditionB,test_conditionA,test_conditionB};
cond_sets = cell(1,4);
for g = 1:4
    if ischar(cond_in{g})
        % str2num evaluates its argument (this is what allows lists and
        % ranges such as '1 2 3' or '1:3'); only pass trusted strings.
        [vals,ok] = str2num(cond_in{g}); %#ok<ST2NM>
        if ~ok
            error('GE_svm_contrast_conditions_perm_diftest:badCondition', ...
                'Could not parse %s condition "%s".', group_names{g}, cond_in{g});
        end
    else
        vals = cond_in{g};
    end
    cond_sets{g} = vals;
end

%% sort epochs into the four groups
% trial{1},trial{2} = training data (A,B); trial{3},trial{4} = test data (A,B).
% Each trial is stored as a channels x time matrix.
n_epochs = size(epoch_data,3);
trial = {{},{},{},{}};
for i_epoch = 1:n_epochs
    % first matching group wins, mirroring the train-before-test priority
    g = find(cellfun(@(s) ismember(condition_index(i_epoch),s), cond_sets), 1);
    if ~isempty(g)
        trial{g}{end+1} = epoch_data(:,:,i_epoch)';
    end
end

n_per_group = cellfun(@numel,trial);
if any(n_per_group == 0)
    error('GE_svm_contrast_conditions_perm_diftest:emptyGroup', ...
        ['No epochs found for: %s. Epochs of a condition used for training ' ...
         'are not reused for testing.'], strjoin(group_names(n_per_group == 0), ', '));
end

ntimes = size(trial{1}{1},2);
nchannels = size(trial{1}{1},1);
ntrials_train = min(n_per_group(1:2)); % equal number of trials per training class
ntrials_test = min(n_per_group(3:4));  % equal number of trials per test class

%% time axis
if isfield(param,'time')
    Time = param.time;
else
    Time = -100.0:1.0:1000.0; % fixed default window
end
if numel(Time) ~= ntimes
    warning('GE_svm_contrast_conditions_perm_diftest:timeMismatch', ...
        ['Time axis has %d samples but the epochs have %d; the baseline ' ...
         'selection and the returned Time may be wrong.'], numel(Time), ntimes);
end

%% correct for baseline std
% Divide every channel by its standard deviation over the pre-stimulus
% samples (Time < 0). Only the first ntrials_train/ntrials_test trials of a
% group are ever drawn into bins, so normalizing all trials gives identical
% results.
baseline_ndx = Time < 0;
for g = 1:4
    for j = 1:numel(trial{g})
        x = trial{g}{j};
        trial{g}{j} = bsxfun(@rdivide, x, std(x(:,baseline_ndx),0,2));
    end
end

%% bin layout and labels
nsamples_train = floor(ntrials_train/trial_bin_size); %number of bins
nsamples_test = floor(ntrials_test/trial_bin_size);   %number of bins
if nsamples_train < 1 || nsamples_test < 1
    error('GE_svm_contrast_conditions_perm_diftest:tooFewTrials', ...
        ['Need at least trial_bin_size (%d) trials per class ' ...
         '(train: %d, test: %d).'], trial_bin_size, ntrials_train, ntrials_test);
end
% each row lists the (pre-permutation) trial indices forming one bin
samples_train = reshape(1:nsamples_train*trial_bin_size,trial_bin_size,nsamples_train)';
samples_test = reshape(1:nsamples_test*trial_bin_size,trial_bin_size,nsamples_test)';

train_label = [ones(1,nsamples_train) 2*ones(1,nsamples_train)];
test_label = [1 2]; % one test sample per class: A then B

%% perform decoding
Accuracy = zeros(num_permutations,ntimes);
for p = 1:num_permutations %randomly organizing trials into bins

    if ~rem(p,10)
        disp(['p = ' num2str(p)]);
    end

    % random trial-to-bin assignment, drawn independently for train and test
    perm_ndx_train = randperm(nsamples_train*trial_bin_size);
    perm_samples_train = perm_ndx_train(samples_train);
    perm_ndx_test = randperm(nsamples_test*trial_bin_size);
    perm_samples_test = perm_ndx_test(samples_test);

    % training samples: bin averages of class A stacked above class B
    % (average_structure2 presumably returns bins x channels x time — confirm)
    train_trials = [average_structure2(trial{1}(perm_samples_train)); ...
                    average_structure2(trial{2}(perm_samples_train))];

    % test samples: average_structure presumably returns one channels x time
    % average per class (required by the reshape below); arranged as
    % 2 x channels x time so row 1 = class A, row 2 = class B
    test_trialsA = average_structure(trial{3}(perm_samples_test));
    test_trialsB = average_structure(trial{4}(perm_samples_test));
    test_trials = permute(reshape([test_trialsA test_trialsB],[nchannels,ntimes,2]),[3 1 2]);

    for t = 1:ntimes
        model = svmtrain(train_trials(:,:,t), train_label','method','LS');
        group = svmclassify(model,test_trials(:,:,t));
        % percent of test samples whose predicted label matches
        Accuracy(p,t) = sum(test_label(:) == group(:)) / numel(test_label) * 100;
    end

end

%average over permutations
accuracy = mean(Accuracy,1);
